# -*- coding: utf-8 -*-
"""
Created on Mon Dec 28 13:46:04 2020

@author: 98330
"""

import numpy as np
import math
from ortools.graph import pywrapgraph

class MinCostFlowWrapper:
    """Min-cost-flow helper on top of OR-Tools' ``SimpleMinCostFlow`` that supports arc lower bounds.

    The solver's arcs only take a capacity, so a lower bound ``l`` on an arc
    ``(s, e)`` is emulated with the usual transformation:

    * the arc is given capacity ``upper - l``;
    * ``l`` units are pre-routed by lowering the supply of ``s`` and raising
      the supply of ``e``;
    * ``l`` is added back to the solver's flow when the solution is read.

    Node supplies are accumulated locally and only pushed to the solver in
    :meth:`Compute`.
    """

    def __init__(self, number_nodes):
        self.min_cost_flow = pywrapgraph.SimpleMinCostFlow()
        self.number_nodes = number_nodes
        # Net supply of each node, built up from the arcs' lower bounds.
        self.supplies = [0] * number_nodes
        # Per-arc bookkeeping in insertion order; Compute() assumes position k
        # in these lists corresponds to the solver's arc index k.
        self.lower_bounds = []
        self.start_nodes = []
        self.end_nodes = []
        # solution[s][e] = total flow from s to e after Compute() (0 if none).
        self.solution = [[0] * number_nodes for _ in range(number_nodes)]

    def AddArcWithCapacityAndUnitCost(self, start_node, end_node, lower_bound,
                                      upper_bound, unit_cost):
        """Add an arc whose flow must lie in ``[lower_bound, upper_bound]``.

        Args:
            start_node, end_node: node indices in ``[0, number_nodes)``.
            lower_bound, upper_bound: flow bounds. The solver only receives
                ``int(upper_bound - lower_bound)`` as the arc capacity.
            unit_cost: cost per unit of flow, truncated with ``int``.
        """
        self.min_cost_flow.AddArcWithCapacityAndUnitCost(
            int(start_node), int(end_node),
            int(upper_bound - lower_bound), int(unit_cost))
        # Pre-route the mandatory lower_bound units through the node supplies.
        self.supplies[start_node] -= lower_bound
        self.supplies[end_node] += lower_bound
        self.lower_bounds.append(lower_bound)
        self.start_nodes.append(start_node)
        self.end_nodes.append(end_node)

    def Compute(self):
        """Push the node supplies, solve, and store the resulting flows.

        Returns:
            True if the solver reports an optimal solution, False otherwise.
            On failure every flow read through GetFlow() is 0.
        """
        for node in range(self.number_nodes):
            self.min_cost_flow.SetNodeSupply(int(node), int(self.supplies[node]))
        # Start from a clean matrix so a repeated Compute() never mixes results.
        self.solution = [[0] * self.number_nodes for _ in range(self.number_nodes)]
        if self.min_cost_flow.Solve() != self.min_cost_flow.OPTIMAL:
            return False
        for arc, (start, end) in enumerate(zip(self.start_nodes, self.end_nodes)):
            # Accumulate, so parallel arcs between the same pair of nodes are
            # summed instead of the last arc overwriting the others.
            self.solution[start][end] += (
                self.min_cost_flow.Flow(arc) + self.lower_bounds[arc])
        return True

    def GetFlow(self, start_node, end_node):
        """Return the total flow from ``start_node`` to ``end_node``.

        The flow comes from the last Compute(), with lower bounds included.
        It is 0 before solving or after a failed solve.
        """
        return self.solution[start_node][end_node]


def Floor(a, eps=10 ** -4):
    """Floor ``a`` with a small tolerance for floating-point noise.

    A value that lies within ``eps`` below an integer is floored to that
    integer. For example, ``2.99999`` floors to ``3`` rather than ``2``.
    Accepts scalars or NumPy arrays (applied element-wise).

    Args:
        a: number or NumPy array.
        eps: tolerance added before flooring (default ``1e-4``).
    Returns:
        ``np.floor(a + eps)``.
    """
    return np.floor(a + eps)


def Ceil(a, eps=10 ** -4):
    """Ceil ``a`` with a small tolerance for floating-point noise.

    A value that lies within ``eps`` above an integer is ceiled to that
    integer. For example, ``2.00001`` ceils to ``2`` rather than ``3``.
    Accepts scalars or NumPy arrays (applied element-wise).

    Args:
        a: number or NumPy array.
        eps: tolerance subtracted before ceiling (default ``1e-4``).
    Returns:
        ``np.ceil(a - eps)``.
    """
    return np.ceil(a - eps)


def GetAccurateFlow(u, server_placement, load, final_optimal):
    """Round a fractional server placement to integers via a min-cost circulation.

    Every fractional quantity ``x`` below becomes an arc whose flow is forced
    into ``[Floor(x), Ceil(x)]``. Solving the min-cost flow then picks one
    integral rounding that satisfies all of these bounds at once.

    Graph layout, with ``n = len(server_placement)`` and ``2 * n + 2`` nodes.
    The source is node ``2 * n`` and the sink is node ``2 * n + 1``.

    * ``source -> i`` (i < n): bounds from ``server_placement[i]``.
    * ``j -> n + i`` (i, j < n): bounds from ``u[i][j] * load[i]``.
    * ``n + i -> sink`` (i < n): bounds from ``load[i] * final_optimal``.
    * ``sink -> source``: bounds from ``sum(server_placement)``. This arc
      closes the circulation.

    Unit costs are scaled by 100 and truncated to integers by the wrapper.

    param:
    ------------
        u : n x n matrix; u[i][j] * load[i] is the fractional flow from
            node j to base-station node n + i
        server_placement : fractional number of servers assigned to each
            base station
        load : the workload of each base station during a period
        final_optimal : the final result after optimizing the target; it
            scales ``load`` for the sink arcs
    return
    ------------
        total_server_used : sum of the rounded server counts
        server_placement_list : rounded server count of each base station,
            i.e. the flow on each ``source -> i`` arc. All entries are 0 if
            the solver does not report an optimal solution.
    """
    num_server = len(server_placement)
    number_nodes = 2 * num_server + 2
    source_index = 2 * num_server
    sink_index = 2 * num_server + 1
    graph = MinCostFlowWrapper(number_nodes)

    # source -> i: round server_placement[i]. The cost of rounding up grows
    # with the distance from the value to its ceiling.
    for node_index in range(num_server):
        placement = server_placement[node_index]
        graph.AddArcWithCapacityAndUnitCost(source_index, node_index,
                                            Floor(placement),
                                            Ceil(placement),
                                            100 * (Ceil(placement) - placement))

    # n + i -> sink: round the scaled load of base station i.
    scaled_load = np.array(load) * final_optimal
    load_min = Floor(scaled_load)
    load_max = Ceil(scaled_load)
    for node_index in range(num_server):
        if load_max[node_index] <= load_min[node_index]:
            # Zero-capacity arc: its flow is pinned to load_min, so its cost
            # never contributes. The inverse-fraction formula below would
            # divide by (near-)zero here. An integral load would give an
            # infinite cost that int() cannot convert, so use 0 instead.
            unit_cost = 0
        else:
            temp = scaled_load[node_index]
            unit_cost = 100 * temp / (temp - load_min[node_index])
        graph.AddArcWithCapacityAndUnitCost(num_server + node_index, sink_index,
                                            load_min[node_index],
                                            load_max[node_index],
                                            unit_cost)

    # j -> n + i: round the share u[i][j] * load[i].
    for node_index in range(num_server):
        u_list = np.array(u[node_index]) * load[node_index]
        for node_index2 in range(num_server):
            share = u_list[node_index2]
            graph.AddArcWithCapacityAndUnitCost(node_index2, num_server + node_index,
                                                Floor(share),
                                                Ceil(share),
                                                100 * (Ceil(share) - share))

    # sink -> source: keep the total number of servers near the fractional sum.
    sum_server = sum(server_placement)
    graph.AddArcWithCapacityAndUnitCost(sink_index, source_index,
                                        Floor(sum_server), Ceil(sum_server), 0)

    graph.Compute()

    server_placement_list = [graph.GetFlow(source_index, node_index)
                             for node_index in range(num_server)]
    total_server_used = sum(server_placement_list)
    return total_server_used, server_placement_list
    
if __name__ == '__main__':
    # Demo: a three-node cycle 0 -> 1 -> 2 -> 0 where every arc has a lower bound.
    # Each tuple is (start, end, lower_bound, upper_bound, unit_cost).
    demo_arcs = (
        (0, 1, 2, 4, 5),
        (1, 2, 3, 5, -6),
        (2, 0, 1, 5, -5),
    )
    demo_graph = MinCostFlowWrapper(3)
    for start, end, low, high, cost in demo_arcs:
        demo_graph.AddArcWithCapacityAndUnitCost(start, end, low, high, cost)
    demo_graph.Compute()
    print(demo_graph.GetFlow(0, 1))
  